#coding=utf-8

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import os
import time
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.utils.data as data
import numpy as np
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
import apex
import torch.multiprocessing as mp

from data.config import cur_config as cfg
from layers.modules import MultiBoxLoss
from data.widerface import WIDERDetection, detection_collate
from models.factory import build_net, basenet_factory

def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so a boolean option could never
    be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(
    description='DSFD face Detector Training With Pytorch')
# No option is registered on this group in this file; the name is kept so the
# module-level attribute stays available.
train_set = parser.add_mutually_exclusive_group()

# --- Model / data loading ---------------------------------------------------
parser.add_argument('--batch_size',
                    default=16, type=int,
                    help='Batch size for training')
parser.add_argument('--model',
                    default='vgg', type=str,
                    choices=['vgg', 'resnet50', 'resnet101', 'resnet152'],
                    help='model for training')
parser.add_argument('--resume',
                    default=None, type=str,
                    help='Checkpoint state_dict file to resume training from')
parser.add_argument('--num_workers',
                    default=4, type=int,
                    help='Number of workers used in dataloading')
parser.add_argument('--npu',
                    default=True, type=_str2bool,
                    help='Use npu to train model')

# --- Optimiser --------------------------------------------------------------
parser.add_argument('--lr', '--learning-rate',
                    default=1e-3, type=float,
                    help='initial learning rate')
parser.add_argument('--momentum',
                    default=0.9, type=float,
                    help='Momentum value for optim')
parser.add_argument('--weight_decay',
                    default=5e-4, type=float,
                    help='Weight decay for SGD')
parser.add_argument('--gamma',
                    default=0.1, type=float,
                    help='Gamma update for SGD')

# --- Distributed training ---------------------------------------------------
parser.add_argument('--multigpu',
                    default=False, type=_str2bool,
                    help='Use mutil Gpu training')
parser.add_argument("--dist_url", help="", default='127.0.0.1:6667', type=str)
parser.add_argument('--nodes', default=1, type=int, metavar='N')
parser.add_argument('--nr', default=0, type=int, help='ranking within the nodes')
parser.add_argument('--npus', default=8, type=int, help='number of gpus per node')
parser.add_argument('--device_id', default=0, type=int)

# --- Paths ------------------------------------------------------------------
parser.add_argument('--save_folder',
                    default='weights/',
                    help='Directory for saving checkpoint models')
parser.add_argument('--pretrain_weight',
                    default='./pretrain_weights/',
                    help='Directory for pretrained checkpoint models')

args = parser.parse_args()

# True for the only process in single-device mode, or for the process bound to
# device 0 in multi-device mode. train() uses it to restrict log-file and
# checkpoint writes to a single process.
args.is_master_node = not args.multigpu or args.device_id == 0

print("p 0")

# Checkpoints are written to <save_folder>/<model>/.
# makedirs(exist_ok=True) also creates a missing parent directory and does not
# fail when several processes try to create the folder at the same time.
save_folder = os.path.join(args.save_folder, args.model)
os.makedirs(save_folder, exist_ok=True)

train_dataset = WIDERDetection(cfg.FACE_TRAIN_FILE, mode='train')
val_dataset = WIDERDetection(cfg.FACE_VAL_FILE, mode='val')

# Validation uses half the training batch size, but never less than one
# sample: DataLoader rejects batch_size=0 (which --batch_size 1 would yield).
val_batchsize = max(1, args.batch_size // 2)
val_loader = data.DataLoader(val_dataset, val_batchsize,
                             num_workers=args.num_workers,
                             shuffle=False,
                             collate_fn=detection_collate,
                             pin_memory=True)

# Lowest validation loss seen so far; val() updates it through `global`.
min_loss = np.inf

def _build_train_loader(args):
    """Create the training DataLoader for single- or multi-device training.

    In multi-device mode this also sets ``MASTER_ADDR``/``MASTER_PORT``,
    stores ``args.world_size``, initialises the ``hccl`` process group and
    shards ``train_dataset`` with a ``DistributedSampler``.
    """
    if not args.multigpu:
        return data.DataLoader(train_dataset, args.batch_size,
                               num_workers=args.num_workers,
                               shuffle=True,
                               collate_fn=detection_collate,
                               pin_memory=True)

    print("multi NPU")
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29688'
    # Total process count = devices per node * number of nodes.
    args.world_size = args.npus * args.nodes
    # NOTE(review): the rank is the local device id only (args.nr is not used),
    # so ranks would collide when nodes > 1 -- confirm before multi-node use.
    # 'hccl' is the backend used for NPU (the original note pairs it with
    # 'nccl' for GPU).
    torch.distributed.init_process_group(backend='hccl',
                                         init_method='env://',
                                         world_size=args.world_size,
                                         rank=args.device_id)
    print(f'all of nodes: {args.world_size}, cur rank: {args.device_id}')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=args.device_id)
    # shuffle must stay False here: ordering is delegated to the sampler.
    # NOTE(review): num_workers is fixed at 8 on this path (args.num_workers is
    # ignored) and an earlier comment suggested drop_last=True -- confirm.
    return data.DataLoader(train_dataset,
                           batch_size=args.batch_size,
                           shuffle=False,
                           pin_memory=False,
                           num_workers=8,
                           sampler=train_sampler,
                           collate_fn=detection_collate,
                           drop_last=False)


def _progress_lines(epoch, iteration, elapsed, mean_loss, pal_losses, lr):
    """Return the progress-report lines shared by the console and the log file.

    ``pal_losses`` is ``(pal1_loc, pal1_conf, pal2_loc, pal2_conf)`` as floats.
    """
    pal1_loc, pal1_conf, pal2_loc, pal2_conf = pal_losses
    return [
        'Timer: %.4f' % elapsed,
        'epoch:' + repr(epoch) + ' || iter:' + repr(iteration)
        + ' || Loss:%.4f' % mean_loss,
        '->> pal1 conf loss:{:.4f} || pal1 loc loss:{:.4f}'.format(
            pal1_conf, pal1_loc),
        '->> pal2 conf loss:{:.4f} || pal2 loc loss:{:.4f}'.format(
            pal2_conf, pal2_loc),
        '->>lr:{}'.format(lr),
    ]


def train(args):
    """Train the DSFD detector, validating after every epoch.

    Args:
        args: parsed command-line namespace. This function adds ``args.device``
            (and ``args.world_size`` in multi-device mode) to it.

    Side effects: appends progress to ``train_npu_8p.txt`` and writes a
    ``dsfd_<iteration>.pth`` snapshot every 500 iterations (master process
    only); ``val()`` writes further checkpoints after each epoch.
    """
    print("enter train")
    per_epoch_size = len(train_dataset) // args.batch_size
    start_epoch = 0
    iteration = 0
    step_index = 0

    args.device = torch.device(f'npu:{args.device_id}')
    torch.npu.set_device(args.device)

    train_loader = _build_train_loader(args)

    # The network is deliberately built after the DataLoader (original note).
    basenet = basenet_factory(args.model)
    dsfd_net = build_net('train', cfg.NUM_CLASSES, args.model)
    net = dsfd_net

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        start_epoch = net.load_weights(args.resume)
        iteration = start_epoch * per_epoch_size
    else:
        # os.path.join works whether or not --pretrain_weight ends with a slash.
        base_weights_path = os.path.join(args.pretrain_weight, basenet)
        base_weights = torch.load(base_weights_path)
        print('Load base network {}'.format(base_weights_path))
        if args.model == 'vgg':
            net.vgg.load_state_dict(base_weights)
        else:
            # Load only the pretrained entries whose keys exist in the model;
            # all other model parameters keep their current values.
            model_dict = net.resnet.state_dict()
            model_dict.update(
                {k: v for k, v in base_weights.items() if k in model_dict})
            net.resnet.load_state_dict(model_dict)

    criterion = MultiBoxLoss(cfg, args.npu, args.device)
    print('Loading wider dataset...')
    print('Using the specified args:')
    print(args)

    net = net.to(args.device)

    optimizer = apex.optimizers.NpuFusedSGD(net.parameters(),
                                            lr=args.lr,
                                            momentum=args.momentum,
                                            weight_decay=args.weight_decay)

    # Mixed precision; per the original author's note, O1 was chosen because
    # training did not converge otherwise.
    net, optimizer = amp.initialize(net, optimizer, opt_level="O1",
                                    loss_scale=None, combine_grad=True)

    if args.npu and args.multigpu:
        # DistributedDataParallel replaces DataParallel for NPU training.
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[args.device_id], broadcast_buffers=False)

    if not args.resume:
        print('Initializing weights...')
        # Everything except the pretrained backbone starts from a fresh init.
        for head in (dsfd_net.extras,
                     dsfd_net.fpn_topdown,
                     dsfd_net.fpn_latlayer,
                     dsfd_net.fpn_fem,
                     dsfd_net.loc_pal1,
                     dsfd_net.conf_pal1,
                     dsfd_net.loc_pal2,
                     dsfd_net.conf_pal2):
            head.apply(dsfd_net.weights_init)

    # When resuming, replay the learning-rate decays that were already passed.
    for step in cfg.LR_STEPS:
        if iteration > step:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

    for epoch in range(start_epoch, cfg.EPOCHES):
        # val() switches the network to eval mode at the end of every epoch,
        # so training mode has to be re-enabled here, not just once up front.
        net.train()
        losses = 0
        if args.multigpu:
            # Gives the distributed sampler a different shuffle each epoch.
            train_loader.sampler.set_epoch(epoch)
        print("get train_loader", train_loader)
        for batch_idx, (images, targets) in enumerate(train_loader):
            # Plain tensors are used directly: the Variable wrapper and its
            # `volatile` flag are deprecated no-ops in current PyTorch.
            if args.npu:
                images = images.to(args.device)
                targets = [ann.to(args.device) for ann in targets]

            if iteration in cfg.LR_STEPS:
                step_index += 1
                adjust_learning_rate(optimizer, args.gamma, step_index)

            t0 = time.time()
            out = net(images)
            print("get net out")
            # backprop
            optimizer.zero_grad()
            # The first three outputs feed the pal1 loss, the rest the pal2 loss.
            loss_l_pal1, loss_c_pal1 = criterion(out[:3], targets)
            loss_l_pal2, loss_c_pal2 = criterion(out[3:], targets)

            loss = loss_l_pal1 + loss_c_pal1 + loss_l_pal2 + loss_c_pal2
            # apex amp scales the loss before the backward pass.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            optimizer.step()
            t1 = time.time()
            losses += loss.item()
            print("get losses", losses)

            if iteration % 10 == 0:
                tloss = losses / (batch_idx + 1)
                lines = _progress_lines(
                    epoch, iteration, t1 - t0, tloss,
                    (loss_l_pal1.item(), loss_c_pal1.item(),
                     loss_l_pal2.item(), loss_c_pal2.item()),
                    optimizer.param_groups[0]['lr'])
                if args.is_master_node:
                    with open("train_npu_8p.txt", "a") as f:
                        f.write('\n'.join(lines) + '\n')
                for line in lines:
                    print(line)

            if iteration != 0 and iteration % 500 == 0:
                print('Saving state, iter:', iteration)
                ckpt_name = 'dsfd_' + repr(iteration) + '.pth'
                if args.is_master_node:
                    # dsfd_net is the unwrapped model, so no `.module` access
                    # is needed (using it raised AttributeError).
                    torch.save(dsfd_net.state_dict(),
                               os.path.join(save_folder, ckpt_name))
            iteration += 1

        val(epoch, net, dsfd_net, criterion)
        # `>=` rather than `==`: the check runs only at epoch boundaries, so an
        # exact match could be stepped over and the limit never enforced.
        if iteration >= cfg.MAX_STEPS:
            break


def val(epoch, net, dsfd_net, criterion):
    """Evaluate on ``val_loader`` and write checkpoints for this epoch.

    Args:
        epoch: index of the epoch that just finished (logged and stored in
            the checkpoint).
        net: the model used for the forward pass.
        dsfd_net: the model whose ``state_dict()`` is saved.
        criterion: loss callable returning a ``(loc_loss, conf_loss)`` pair.

    The reported loss is the mean over batches of the pal2 localisation plus
    confidence loss. When it improves on the module-level ``min_loss`` the
    weights are saved as ``dsfd.pth``; ``dsfd_checkpoint.pth`` is written every
    call. Files are written by the master process only, matching ``train``.
    The train/eval mode ``net`` had on entry is restored before returning.
    """
    global min_loss

    was_training = net.training
    net.eval()
    step = 0
    losses = 0
    t1 = time.time()
    try:
        # No gradients are needed for evaluation; without no_grad() every
        # forward pass would build an autograd graph. (The deprecated
        # `volatile` flag used previously has no effect in current PyTorch.)
        with torch.no_grad():
            for images, targets in val_loader:
                if args.npu:
                    images = images.npu()
                    targets = [ann.npu() for ann in targets]

                out = net(images)
                # NOTE(review): the pal1 losses are computed but not used in
                # the validation metric; the call is kept unchanged -- confirm
                # whether it can be dropped.
                loss_l_pal1, loss_c_pal1 = criterion(out[:3], targets)
                loss_l_pal2, loss_c_pal2 = criterion(out[3:], targets)
                loss = loss_l_pal2 + loss_c_pal2
                losses += loss.item()
                step += 1
    finally:
        net.train(was_training)

    # An empty loader yields no batches; report inf instead of dividing by 0
    # (inf can never be mistaken for a new best loss).
    tloss = losses / step if step else float('inf')
    t2 = time.time()
    print('Timer: %.4f' % (t2 - t1))
    print('test epoch:' + repr(epoch) + ' || Loss:%.4f' % (tloss))

    is_best = tloss < min_loss
    if is_best:
        min_loss = tloss

    # Only one process writes the files so that concurrent ranks cannot
    # corrupt them by writing the same path at once.
    if args.is_master_node:
        if is_best:
            print('Saving best state,epoch', epoch)
            torch.save(dsfd_net.state_dict(),
                       os.path.join(save_folder, 'dsfd.pth'))

        states = {
            'epoch': epoch,
            'weight': dsfd_net.state_dict(),
        }
        torch.save(states, os.path.join(save_folder, 'dsfd_checkpoint.pth'))

def adjust_learning_rate(optimizer, gamma, step):
    """Set every parameter group's learning rate to ``args.lr * gamma ** step``.

    Args:
        optimizer: object exposing ``param_groups``, a sequence of dicts.
        gamma: multiplicative decay factor applied once per step.
        step: number of decays applied so far.

    The base rate is read from the module-level ``args.lr``.
    """
    decayed_lr = args.lr * (gamma ** step)
    for group in optimizer.param_groups:
        group.update({'lr': decayed_lr})

if __name__ == "__main__":
    # Script entry point: run training in this process with the parsed
    # command-line options.
    print("main in")
    train(args)
